// Copyright 2024 Mozilla Foundation
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the
// "Software"), to deal in the Software without restriction, including
// without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to
// permit persons to whom the Software is furnished to do so, subject to
// the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
// BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
// ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

//
//                   _   _          ___ _      _   ___
//                  | |_(_)_ _ _  _| _ ) |    /_\ / __|
//                  |  _| | ' \ || | _ \ |__ / _ \\__ \.
//                   \__|_|_||_\_, |___/____/_/ \_\___/
//                             |__/
//
//                    BASIC LINEAR ALGEBRA SUBPROGRAMS
//
//
// This file implements multithreaded CPU matrix multiplication for the
// common contiguous use case C = Aᵀ * B. These kernels are designed to
// have excellent performance[1] for matrices that fit in the CPU cache
// without imposing any overhead such as cache filling or malloc calls.
//
// This implementation does not guarantee any upper bound on rounding
// errors, which grow with k. Our goal is to maximally exploit the
// hardware for performance, and then use whatever resources remain to
// improve numerical accuracy.
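//
// As a scalar point of reference, the fp32 case computes the following
// (a sketch; both A and B store k contiguous elements per row, and C is
// written column major):
//
//     for (int64_t i = 0; i < m; ++i)
//       for (int64_t j = 0; j < n; ++j) {
//         float sum = 0;
//         for (int64_t l = 0; l < k; ++l)
//           sum += A[lda * i + l] * B[ldb * j + l];
//         C[ldc * j + i] = sum;
//       }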
//
// [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online].
//     Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].

#if defined(__GNUC__)
#pragma GCC diagnostic ignored "-Wpedantic"
#pragma GCC diagnostic ignored "-Wignored-attributes"
#endif

#include "mllm/utils/Common.hpp"
#include "mllm/backends/cpu/kernels/common/llamafile/llamafile_sgemm.hpp"
#include "mllm/backends/cpu/kernels/common/ggml/quantize/quantize.hpp"

#ifdef _MSC_VER
#define NOINLINE __declspec(noinline)
#else
#define NOINLINE __attribute__((__noinline__))
#endif

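// AVX512 and AArch64 NEON targets provide 32 vector registers, while
// SSE/AVX provide 16; the register count bounds the register-blocked
// tile shapes that mnpack() may select below.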
#if defined(__ARM_NEON) || defined(__AVX512F__)
#define VECTOR_REGISTERS 32
#else
#define VECTOR_REGISTERS 16
#endif

#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)

namespace mllm::cpu {

namespace MLLM_ANONYMOUS_NAMESPACE {

inline float unhalf(mllm_fp16_t d) { return MLLM_FP16_TO_FP32(d); }

////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED ARITHMETIC OPERATIONS

#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline __m128 add(__m128 x, __m128 y) { return _mm_add_ps(x, y); }
inline __m128 sub(__m128 x, __m128 y) { return _mm_sub_ps(x, y); }
inline __m128 mul(__m128 x, __m128 y) { return _mm_mul_ps(x, y); }
#endif  // __SSE__

#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline __m256 add(__m256 x, __m256 y) { return _mm256_add_ps(x, y); }
inline __m256 sub(__m256 x, __m256 y) { return _mm256_sub_ps(x, y); }
inline __m256 mul(__m256 x, __m256 y) { return _mm256_mul_ps(x, y); }
#endif  // __AVX__

#if defined(__AVX512F__)
inline __m512 add(__m512 x, __m512 y) { return _mm512_add_ps(x, y); }
inline __m512 sub(__m512 x, __m512 y) { return _mm512_sub_ps(x, y); }
inline __m512 mul(__m512 x, __m512 y) { return _mm512_mul_ps(x, y); }
#endif  // __AVX512F__

#if defined(__ARM_NEON)
inline float32x4_t add(float32x4_t x, float32x4_t y) { return vaddq_f32(x, y); }
inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vsubq_f32(x, y); }
inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vmulq_f32(x, y); }
#endif  // __ARM_NEON

#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
inline float16x8_t add(float16x8_t x, float16x8_t y) { return vaddq_f16(x, y); }
inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
#endif  // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED FUSED MULTIPLY ADD

/**
 * Computes a * b + c.
 */
template<typename T, typename U>
inline U madd(T a, T b, U c) {
  return add(mul(a, b), c);
}

#if defined(__FMA__)
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
template<>
inline __m256 madd(__m256 a, __m256 b, __m256 c) {
  return _mm256_fmadd_ps(a, b, c);
}
#endif
#if defined(__AVX512F__)
template<>
inline __m512 madd(__m512 a, __m512 b, __m512 c) {
  return _mm512_fmadd_ps(a, b, c);
}
#endif
#endif

#if defined(__ARM_FEATURE_FMA)
template<>
inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
  return vfmaq_f32(c, b, a);
}
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
template<>
inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
  return vfmaq_f16(c, b, a);
}
#endif
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED HORIZONTAL SUM

#if defined(__ARM_NEON)
inline float hsum(float32x4_t x) { return vaddvq_f32(x); }
#endif  // __ARM_NEON

#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
inline float hsum(float16x8_t x) {
  return vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(x)), vcvt_f32_f16(vget_high_f16(x))));
}
#endif  // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline float hsum(__m128 x) {
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
  x = _mm_add_ps(x, _mm_movehl_ps(x, x));
  x = _mm_add_ss(x, _mm_movehdup_ps(x));
#else
  __m128 t;
  t = _mm_shuffle_ps(x, x, _MM_SHUFFLE(2, 3, 0, 1));
  x = _mm_add_ps(x, t);
  t = _mm_movehl_ps(t, x);
  x = _mm_add_ss(x, t);
#endif
  return _mm_cvtss_f32(x);
}
#endif

#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline float hsum(__m256 x) { return hsum(_mm_add_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x))); }
#endif  // __AVX__

#if defined(__AVX512F__)
inline float hsum(__m512 x) { return _mm512_reduce_add_ps(x); }
#endif  // __AVX512F__

////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED MEMORY LOADING

template<typename T, typename U>
T load(const U*);
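
// Loads a vector of type T from memory holding scalars of type U,
// converting lanes where the types differ (e.g. widening fp16 to fp32).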

#if defined(__ARM_NEON)
template<>
inline float32x4_t load(const float* p) {
  return vld1q_f32(p);
}
#if !defined(_MSC_VER)
template<>
inline float16x8_t load(const mllm_fp16_t* p) {
  return vld1q_f16((const float16_t*)p);
}
template<>
inline float32x4_t load(const mllm_fp16_t* p) {
  return vcvt_f32_f16(vld1_f16((const float16_t*)p));
}
#endif  // _MSC_VER
#endif  // __ARM_NEON

#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
template<>
inline __m128 load(const float* p) {
  return _mm_loadu_ps(p);
}
#endif  // __SSE__

#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
template<>
inline __m256 load(const float* p) {
  return _mm256_loadu_ps(p);
}
#endif  // __AVX__

#if defined(__F16C__)
template<>
inline __m256 load(const mllm_fp16_t* p) {
  return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)p));
}
#endif  // __F16C__

#if defined(__AVX512F__)
template<>
inline __m512 load(const float* p) {
  return _mm512_loadu_ps(p);
}
template<>
inline __m512 load(const mllm_fp16_t* p) {
  return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)p));
}
#endif  // __AVX512F__

////////////////////////////////////////////////////////////////////////////////////////////////////
// FLOATING POINT MATRIX MULTIPLICATION

template<int KN, typename D, typename V, typename TA, typename TB, typename TC>
class tinyBLAS {
 public:
  tinyBLAS(int64_t k, const TA* A, int64_t lda, const TB* B, int64_t ldb, TC* C, int64_t ldc, int ith, int nth,
           const float* bias = nullptr)
      : A(A), B(B), bias(bias), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {}

  void matmul(int64_t m, int64_t n) { mnpack(0, m, 0, n); }

 private:
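  // Selects the largest RM×RN register-blocked kernel that fits the
  // remaining m×n region, covers all full tiles with it, then recurses
  // on the strip of leftover rows and the strip of leftover columns.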
  NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
    int64_t mc, nc, mp, np;
    switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) {
#if VECTOR_REGISTERS == 32
      case 0x55:
        mc = 5;
        nc = 5;
        gemm<5, 5>(m0, m, n0, n);
        break;
      case 0x45:
        mc = 4;
        nc = 5;
        gemm<4, 5>(m0, m, n0, n);
        break;
      case 0x54:
        mc = 5;
        nc = 4;
        gemm<5, 4>(m0, m, n0, n);
        break;
      case 0x44:
        mc = 4;
        nc = 4;
        gemm<4, 4>(m0, m, n0, n);
        break;
      case 0x53:
        mc = 5;
        nc = 3;
        gemm<5, 3>(m0, m, n0, n);
        break;
      case 0x35:
        mc = 3;
        nc = 5;
        gemm<3, 5>(m0, m, n0, n);
        break;
      case 0x43:
        mc = 4;
        nc = 3;
        gemm<4, 3>(m0, m, n0, n);
        break;
#else
      case 0x55:
      case 0x54:
      case 0x53:
      case 0x45:
      case 0x44:
      case 0x43:
        mc = 4;
        nc = 3;
        gemm<4, 3>(m0, m, n0, n);
        break;
      case 0x35:
#endif
      case 0x34:
        mc = 3;
        nc = 4;
        gemm<3, 4>(m0, m, n0, n);
        break;
      case 0x52:
        mc = 5;
        nc = 2;
        gemm<5, 2>(m0, m, n0, n);
        break;
      case 0x33:
        mc = 3;
        nc = 3;
        gemm<3, 3>(m0, m, n0, n);
        break;
      case 0x25:
        mc = 2;
        nc = 5;
        gemm<2, 5>(m0, m, n0, n);
        break;
      case 0x42:
        mc = 4;
        nc = 2;
        gemm<4, 2>(m0, m, n0, n);
        break;
      case 0x24:
        mc = 2;
        nc = 4;
        gemm<2, 4>(m0, m, n0, n);
        break;
      case 0x32:
        mc = 3;
        nc = 2;
        gemm<3, 2>(m0, m, n0, n);
        break;
      case 0x23:
        mc = 2;
        nc = 3;
        gemm<2, 3>(m0, m, n0, n);
        break;
      case 0x51:
        mc = 5;
        nc = 1;
        gemm<5, 1>(m0, m, n0, n);
        break;
      case 0x41:
        mc = 4;
        nc = 1;
        gemm<4, 1>(m0, m, n0, n);
        break;
      case 0x22:
        mc = 2;
        nc = 2;
        gemm<2, 2>(m0, m, n0, n);
        break;
      case 0x15:
        mc = 1;
        nc = 5;
        gemm<1, 5>(m0, m, n0, n);
        break;
      case 0x14:
        mc = 1;
        nc = 4;
        gemm<1, 4>(m0, m, n0, n);
        break;
      case 0x31:
        mc = 3;
        nc = 1;
        gemm<3, 1>(m0, m, n0, n);
        break;
      case 0x13:
        mc = 1;
        nc = 3;
        gemm<1, 3>(m0, m, n0, n);
        break;
      case 0x21:
        mc = 2;
        nc = 1;
        gemm<2, 1>(m0, m, n0, n);
        break;
      case 0x12:
        mc = 1;
        nc = 2;
        gemm<1, 2>(m0, m, n0, n);
        break;
      case 0x11:
        mc = 1;
        nc = 1;
        gemm<1, 1>(m0, m, n0, n);
        break;
      default: return;
    }
    mp = m0 + (m - m0) / mc * mc;
    np = n0 + (n - n0) / nc * nc;
    mnpack(mp, m, n0, np);
    mnpack(m0, m, np, n);
  }

  template<int RM, int RN>
  NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
    int64_t ytiles = (m - m0) / RM;
    int64_t xtiles = (n - n0) / RN;
    int64_t tiles = xtiles * ytiles;
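    // Tiles are dealt out in contiguous chunks: thread ith takes
    // ceil(tiles / nth) consecutive tiles, so the threads need no
    // synchronization beyond agreeing on ith and nth.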
    int64_t duty = (tiles + nth - 1) / nth;
    int64_t start = duty * ith;
    int64_t end = start + duty;
    if (end > tiles) end = tiles;
    for (int64_t job = start; job < end; ++job) {
      int64_t ii = m0 + job / xtiles * RM;
      int64_t jj = n0 + job % xtiles * RN;
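      // One vector accumulator per output element of the RM×RN tile; the
      // k dimension is swept KN lanes at a time and reduced with hsum()
      // when the tile is stored.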
      D Cv[RN][RM] = {};
      for (int64_t l = 0; l < k; l += KN) {
        for (int64_t j = 0; j < RN; ++j) {
          for (int64_t i = 0; i < RM; ++i) {
            Cv[j][i] = madd(load<V>(A + lda * (ii + i) + l), load<V>(B + ldb * (jj + j) + l), Cv[j][i]);
          }
        }
      }

      if (bias) {
        for (int64_t j = 0; j < RN; ++j) {
          for (int64_t i = 0; i < RM; ++i) C[ldc * (jj + j) + (ii + i)] = bias[ii + i] + hsum(Cv[j][i]);
        }
      } else {
        for (int64_t j = 0; j < RN; ++j) {
          for (int64_t i = 0; i < RM; ++i) C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
        }
      }
    }
  }

  const TA* const A;
  const TB* const B;
  const float* const bias;
  TC* const C;
  const int64_t k;
  const int64_t lda;
  const int64_t ldb;
  const int64_t ldc;
  const int ith;
  const int nth;
};

//////////////////////////////////////////////////////////////////////////////////////////
// QUANT ZERO MATRIX MULTIPLICATION

#if defined(__ARM_FEATURE_DOTPROD)
template<typename TA>
class tinyBLAS_Q0_ARM {
 public:
  tinyBLAS_Q0_ARM(int64_t k, const TA* A, int64_t lda, const block_q8_0* B, int64_t ldb, float* C, int64_t ldc, int ith,
                  int nth, const float* bias = nullptr)
      : A(A), B(B), bias(bias), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {}

  void matmul(int64_t m, int64_t n) { mnpack(0, m, 0, n); }

 private:
  NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
    int64_t mc, nc, mp, np;
    switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) {
      case 0x33:
        mc = 3;
        nc = 3;
        gemm<3, 3>(m0, m, n0, n);
        break;
      case 0x32:
        mc = 3;
        nc = 2;
        gemm<3, 2>(m0, m, n0, n);
        break;
      case 0x23:
        mc = 2;
        nc = 3;
        gemm<2, 3>(m0, m, n0, n);
        break;
      case 0x22:
        mc = 2;
        nc = 2;
        gemm<2, 2>(m0, m, n0, n);
        break;
      case 0x31:
        mc = 3;
        nc = 1;
        gemm<3, 1>(m0, m, n0, n);
        break;
      case 0x13:
        mc = 1;
        nc = 3;
        gemm<1, 3>(m0, m, n0, n);
        break;
      case 0x21:
        mc = 2;
        nc = 1;
        gemm<2, 1>(m0, m, n0, n);
        break;
      case 0x12:
        mc = 1;
        nc = 2;
        gemm<1, 2>(m0, m, n0, n);
        break;
      case 0x11:
        mc = 1;
        nc = 1;
        gemm<1, 1>(m0, m, n0, n);
        break;
      default: return;
    }
    mp = m0 + (m - m0) / mc * mc;
    np = n0 + (n - n0) / nc * nc;
    mnpack(mp, m, n0, np);
    mnpack(m0, m, np, n);
  }

  template<int RM, int RN>
  NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
    int64_t ytiles = (m - m0) / RM;
    int64_t xtiles = (n - n0) / RN;
    int64_t tiles = xtiles * ytiles;
    int64_t duty = (tiles + nth - 1) / nth;
    int64_t start = duty * ith;
    int64_t end = start + duty;
    if (end > tiles) end = tiles;
    for (int64_t job = start; job < end; ++job) {
      int64_t ii = m0 + job / xtiles * RM;
      int64_t jj = n0 + job % xtiles * RN;
      float32x4_t Cv[RN][RM] = {};
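      // Each quant block packs 32 quants behind one fp16 scale d. The
      // nested vdotq_s32 calls below sum all 32 int8 products into four
      // i32 lanes, which are then scaled by d(A)·d(B) and accumulated.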
      for (int64_t l = 0; l < k; ++l) {
        for (int64_t j = 0; j < RN; ++j) {
          for (int64_t i = 0; i < RM; ++i) {
            Cv[j][i] = vmlaq_n_f32(Cv[j][i],
                                   vcvtq_f32_s32(vdotq_s32(vdotq_s32(vdupq_n_s32(0), load_lo(A + lda * (ii + i) + l),
                                                                     load_lo(B + ldb * (jj + j) + l)),
                                                           load_hi(A + lda * (ii + i) + l), load_hi(B + ldb * (jj + j) + l))),
                                   unhalf(A[lda * (ii + i) + l].d) * unhalf(B[ldb * (jj + j) + l].d));
          }
        }
      }
      if (bias) {
        for (int64_t j = 0; j < RN; ++j) {
          for (int64_t i = 0; i < RM; ++i) C[ldc * (jj + j) + (ii + i)] = bias[ii + i] + hsum(Cv[j][i]);
        }
      } else {
        for (int64_t j = 0; j < RN; ++j) {
          for (int64_t i = 0; i < RM; ++i) C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
        }
      }
    }
  }

  inline int8x16_t load_lo(const block_q8_0* b) { return vld1q_s8(b->qs); }

  inline int8x16_t load_hi(const block_q8_0* b) { return vld1q_s8(b->qs + 16); }

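  // q4_0 packs two quants per byte: the low nibbles come first, then the
  // high nibbles, and subtracting the zero point of 8 recenters the
  // values into [-8, 7].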
  inline int8x16_t load_lo(const block_q4_0* b) {
    return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs), vdupq_n_u8(0x0f))), vdupq_n_s8(0x8));
  }

  inline int8x16_t load_hi(const block_q4_0* b) {
    return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)), vdupq_n_s8(0x8));
  }

  const TA* const A;
  const block_q8_0* const B;
  const float* const bias;
  float* const C;
  const int64_t k;
  const int64_t lda;
  const int64_t ldb;
  const int64_t ldc;
  const int ith;
  const int nth;
};
#endif  // __ARM_FEATURE_DOTPROD

#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
template<typename TA, typename TB, typename TC>
class tinyBLAS_Q0_AVX {
 public:
  tinyBLAS_Q0_AVX(int64_t k, const TA* A, int64_t lda, const TB* B, int64_t ldb, TC* C, int64_t ldc, int ith, int nth,
                  const float* bias = nullptr)
      : A(A), B(B), bias(bias), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {}

  void matmul(int64_t m, int64_t n) { mnpack(0, m, 0, n); }

 private:
  void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
    int64_t mc, nc, mp, np;
    switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) {
#if VECTOR_REGISTERS == 32
      case 0x44:
        mc = 4;
        nc = 4;
        gemm<4, 4>(m0, m, n0, n);
        break;
      case 0x43:
        mc = 4;
        nc = 3;
        gemm<4, 3>(m0, m, n0, n);
        break;
      case 0x34:
        mc = 3;
        nc = 4;
        gemm<3, 4>(m0, m, n0, n);
        break;
      case 0x33:
        mc = 3;
        nc = 3;
        gemm<3, 3>(m0, m, n0, n);
        break;
      case 0x42:
        mc = 4;
        nc = 2;
        gemm<4, 2>(m0, m, n0, n);
        break;
      case 0x24:
        mc = 2;
        nc = 4;
        gemm<2, 4>(m0, m, n0, n);
        break;
#else
      case 0x44:
      case 0x43:
      case 0x42:
        mc = 4;
        nc = 2;
        gemm<4, 2>(m0, m, n0, n);
        break;
      case 0x34:
      case 0x24:
        mc = 2;
        nc = 4;
        gemm<2, 4>(m0, m, n0, n);
        break;
      case 0x33:
#endif
      case 0x32:
        mc = 3;
        nc = 2;
        gemm<3, 2>(m0, m, n0, n);
        break;
      case 0x23:
        mc = 2;
        nc = 3;
        gemm<2, 3>(m0, m, n0, n);
        break;
      case 0x41:
        mc = 4;
        nc = 1;
        gemm<4, 1>(m0, m, n0, n);
        break;
      case 0x22:
        mc = 2;
        nc = 2;
        gemm<2, 2>(m0, m, n0, n);
        break;
      case 0x14:
        mc = 1;
        nc = 4;
        gemm<1, 4>(m0, m, n0, n);
        break;
      case 0x31:
        mc = 3;
        nc = 1;
        gemm<3, 1>(m0, m, n0, n);
        break;
      case 0x13:
        mc = 1;
        nc = 3;
        gemm<1, 3>(m0, m, n0, n);
        break;
      case 0x21:
        mc = 2;
        nc = 1;
        gemm<2, 1>(m0, m, n0, n);
        break;
      case 0x12:
        mc = 1;
        nc = 2;
        gemm<1, 2>(m0, m, n0, n);
        break;
      case 0x11:
        mc = 1;
        nc = 1;
        gemm<1, 1>(m0, m, n0, n);
        break;
      default: return;
    }
    mp = m0 + (m - m0) / mc * mc;
    np = n0 + (n - n0) / nc * nc;
    mnpack(mp, m, n0, np);
    mnpack(m0, m, np, n);
  }

  template<int RM, int RN>
  NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
    int64_t ytiles = (m - m0) / RM;
    int64_t xtiles = (n - n0) / RN;
    int64_t tiles = xtiles * ytiles;
    int64_t duty = (tiles + nth - 1) / nth;
    int64_t start = duty * ith;
    int64_t end = start + duty;
    if (end > tiles) end = tiles;
    for (int64_t job = start; job < end; ++job) {
      int64_t ii = m0 + job / xtiles * RM;
      int64_t jj = n0 + job % xtiles * RN;
      __m256 Cv[RN][RM] = {};
      for (int64_t l = 0; l < k; ++l)
        for (int64_t j = 0; j < RN; ++j)
          for (int64_t i = 0; i < RM; ++i) {
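            // maddubs/dpbusd require an unsigned first operand, so a·b is
            // computed as |a|·sign(b, a): sign(a, a) = |a|, and copying
            // a's sign onto b leaves every per-lane product unchanged.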
#if defined(__AVX2__)
            __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), load(A + lda * (ii + i) + l)),
                                 _mm256_sign_epi8(load(B + ldb * (jj + j) + l), load(A + lda * (ii + i) + l)));
#else
            __m128i ali0 = load0(A + lda * (ii + i) + l);
            __m128i ali1 = load1(A + lda * (ii + i) + l);
            __m128i blj0 = load0(B + ldb * (jj + j) + l);
            __m128i blj1 = load1(B + ldb * (jj + j) + l);

            __m128i sepAA0 = _mm_sign_epi8(ali0, ali0);
            __m128i sepAA1 = _mm_sign_epi8(ali1, ali1);
            __m128i sepBA0 = _mm_sign_epi8(blj0, ali0);
            __m128i sepBA1 = _mm_sign_epi8(blj1, ali1);

            // updot
            const __m128i oneFill = _mm_set1_epi16(1);
            __m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0);
            __m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1);
            __m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0)));
#endif
            Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) * unhalf(B[ldb * (jj + j) + l].d)), udTmp, Cv[j][i]);
          }
      if (bias) {
        for (int64_t j = 0; j < RN; ++j)
          for (int64_t i = 0; i < RM; ++i) C[ldc * (jj + j) + (ii + i)] = bias[ii + i] + hsum(Cv[j][i]);
      } else {
        for (int64_t j = 0; j < RN; ++j)
          for (int64_t i = 0; i < RM; ++i) C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
      }
    }
  }

  inline __m256i load(const block_q8_0* b) { return _mm256_loadu_si256((const __m256i*)b->qs); }

  inline __m128i load0(const block_q8_0* b) { return _mm_loadu_si128((const __m128i*)b->qs); }

  inline __m128i load1(const block_q8_0* b) { return _mm_loadu_si128(((const __m128i*)b->qs) + 1); }

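  // q4_0 loaders: unpack 32 packed 4-bit quants into signed bytes (low
  // nibbles in load0, high nibbles in load1) and subtract the zero point
  // of 8.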
  inline __m256i load(const block_q4_0* b) { return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8)); }

  inline __m128i load0(const block_q4_0* b) {
    const __m128i x = _mm_loadu_si128((const __m128i*)(b->qs));
    return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8));
  }

  inline __m128i load1(const block_q4_0* b) {
    const __m128i x = _mm_loadu_si128((const __m128i*)(b->qs));
    return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
  }

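  // Dot product of u8×s8 byte pairs, four per i32 lane: uses VNNI when
  // available, otherwise emulates it with maddubs (u8×s8 → i16 pairs)
  // followed by madd against ones to widen to i32.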
  inline __m256 updot(__m256i u, __m256i s) {
    __m256i res;
#if defined(__AVXVNNI__)
    res = _mm256_dpbusd_avx_epi32(_mm256_setzero_si256(), u, s);
#elif defined(__AVX512VNNI__) && defined(__AVX512VL__)
    res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);
#else
    res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
#endif
    return _mm256_cvtepi32_ps(res);
  }

  static inline __m256i denibble(const uint8_t* p) {
    __m128i x = _mm_loadu_si128((const __m128i*)p);
    return _mm256_and_si256(_mm256_set1_epi8(15), _mm256_insertf128_si256(_mm256_castsi128_si256(x), _mm_srli_epi16(x, 4), 1));
  }

  const TA* const A;
  const TB* const B;
  const float* const bias;
  TC* const C;
  const int64_t k;
  const int64_t lda;
  const int64_t ldb;
  const int64_t ldc;
  const int ith;
  const int nth;
};
#endif  // __AVX__

}  // namespace MLLM_ANONYMOUS_NAMESPACE

/**
 * Performs optimized matrix multiplication on CPU.
 *
 * This subroutine may compute C = Aᵀ * B with column major ordering.
 * Despite its name, this isn't a generalized implementation: work is
 * only performed when a handwritten kernel is available for the given
 * types and build target. Otherwise the caller should fall back to a
 * general matmul routine.
 *
 * For example, for single-threaded single-precision GEMM you can say
 *
 *     llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc,
 *                     0, 1,
 *                     MLLM_TYPE_F32, MLLM_TYPE_F32, MLLM_TYPE_F32);
 *
 * @param m is rows in `A` and `C`
 * @param n is cols in `B` and `C`
 * @param k is cols in `A` and rows in `B`
 * @param A is first input matrix (always transposed)
 * @param lda is row stride of `A`
 * @param B is second input matrix (never transposed)
 * @param ldb is row stride of `B`
 * @param C is input/output array of output matrices
 * @param ldc is row stride of `C`
 * @param ith is thread id (must be less than `nth`)
 * @param nth is number of threads (must be greater than zero)
 * @param Atype is GGML data type of `A`
 * @param Btype is GGML data type of `B`
 * @param Ctype is GGML data type of `C`
 * @param bias is optional per-row bias vector of length `m` added to
 *        each column of `C` (may be `nullptr`)
 * @param BiasType is data type of `bias`; must be `MLLM_TYPE_F32` when
 *        `bias` is non-null
 * @return true if this function was able to service the matmul request
 */
bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void* A, int64_t lda, const void* B, int64_t ldb, void* C,
                     int64_t ldc, int ith, int nth, DataTypes Atype, DataTypes Btype, DataTypes Ctype, void* bias,
                     DataTypes BiasType) {
  assert(m >= 0);
  assert(n >= 0);
  assert(k >= 0);
  assert(lda >= k);
  assert(ldb >= k);
  assert(ldc >= m);
  assert(nth > 0);
  assert(ith < nth);

  if (bias && BiasType != MLLM_TYPE_F32) return false;

  if (Ctype != MLLM_TYPE_F32) return false;

  switch (Atype) {
    case MLLM_TYPE_F32: {
      if (Btype != MLLM_TYPE_F32) return false;
#if defined(__AVX512F__)
      if (k % 16) return false;
      tinyBLAS<16, __m512, __m512, float, float, float> tb{k,   (const float*)A, lda, (const float*)B, ldb, (float*)C, ldc, ith,
                                                           nth, (float*)bias};
      tb.matmul(m, n);
      return true;
#elif defined(__AVX__) || defined(__AVX2__)
      if (k % 8) return false;
      tinyBLAS<8, __m256, __m256, float, float, float> tb{k,   (const float*)A, lda, (const float*)B, ldb, (float*)C, ldc, ith,
                                                          nth, (float*)bias};
      tb.matmul(m, n);
      return true;
#elif defined(__ARM_NEON)
      if (n < 4) return false;
      if (k % 4) return false;
      tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{
          k, (const float*)A, lda, (const float*)B, ldb, (float*)C, ldc, ith, nth, (float*)bias};
      tb.matmul(m, n);
      return true;
#else
      return false;
#endif
    }

    case MLLM_TYPE_F16: {
#if defined(__AVX512F__)
      if (k % 16) return false;
      if (Btype != MLLM_TYPE_F32) return false;
      tinyBLAS<16, __m512, __m512, mllm_fp16_t, float, float> tb{
          k, (const mllm_fp16_t*)A, lda, (const float*)B, ldb, (float*)C, ldc, ith, nth, (float*)bias};
      tb.matmul(m, n);
      return true;
#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
      if (k % 8) return false;
      if (Btype != MLLM_TYPE_F32) return false;
      tinyBLAS<8, __m256, __m256, mllm_fp16_t, float, float> tb{
          k, (const mllm_fp16_t*)A, lda, (const float*)B, ldb, (float*)C, ldc, ith, nth, (float*)bias};
      tb.matmul(m, n);
      return true;
#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
      if (n < 8) return false;
      if (k % 8) return false;
      if (Btype != MLLM_TYPE_F16) return false;
      tinyBLAS<8, float16x8_t, float16x8_t, mllm_fp16_t, mllm_fp16_t, float> tb{
          k, (const mllm_fp16_t*)A, lda, (const mllm_fp16_t*)B, ldb, (float*)C, ldc, ith, nth, (float*)bias};
      tb.matmul(m, n);
      return true;
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
      if (k % 4) return false;
      if (Btype != MLLM_TYPE_F32) return false;
      tinyBLAS<4, float32x4_t, float32x4_t, mllm_fp16_t, float, float> tb{
          k, (const mllm_fp16_t*)A, lda, (const float*)B, ldb, (float*)C, ldc, ith, nth, (float*)bias};
      tb.matmul(m, n);
      return true;
#else
      return false;
#endif
    }

    case MLLM_TYPE_Q8_0: {
      if (Btype != MLLM_TYPE_Q8_0) return false;
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
      tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float> tb{
          k, (const block_q8_0*)A, lda, (const block_q8_0*)B, ldb, (float*)C, ldc, ith, nth, (float*)bias};
      tb.matmul(m, n);
      return true;
#elif defined(__ARM_FEATURE_DOTPROD)
      tinyBLAS_Q0_ARM<block_q8_0> tb{k,   (const block_q8_0*)A, lda, (const block_q8_0*)B, ldb, (float*)C, ldc, ith,
                                     nth, (float*)bias};
      tb.matmul(m, n);
      return true;
#else
      return false;
#endif
    }

    case MLLM_TYPE_Q4_0: {
      if (Btype != MLLM_TYPE_Q8_0) return false;
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
      tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float> tb{
          k, (const block_q4_0*)A, lda, (const block_q8_0*)B, ldb, (float*)C, ldc, ith, nth, (float*)bias};
      tb.matmul(m, n);
      return true;
#elif defined(__ARM_FEATURE_DOTPROD)
      tinyBLAS_Q0_ARM<block_q4_0> tb{k,   (const block_q4_0*)A, lda, (const block_q8_0*)B, ldb, (float*)C, ldc, ith,
                                     nth, (float*)bias};
      tb.matmul(m, n);
      return true;
#else
      return false;
#endif
    }

    default: return false;
  }

}

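/**
 * Reports whether llamafile_sgemm() can service a matmul request with the
 * given shapes and data types on the current build target, without doing
 * any work. A sketch of the intended call pattern (`fallback_matmul` is a
 * hypothetical caller-provided routine):
 *
 *     if (check_llamafile_sgemm(m, n, k, Atype, Btype, Ctype, lda, ldb, ldc))
 *       llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, 0, 1,
 *                       Atype, Btype, Ctype, nullptr, MLLM_TYPE_F32);
 *     else
 *       fallback_matmul(m, n, k, A, lda, B, ldb, C, ldc);
 */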
bool check_llamafile_sgemm(int64_t m, int64_t n, int64_t k, DataTypes Atype, DataTypes Btype, DataTypes Ctype, int64_t lda,
                           int64_t ldb, int64_t ldc) {
  assert(m >= 0);
  assert(n >= 0);
  assert(k >= 0);

  if (lda < k) return false;
  if (ldb < k) return false;
  if (ldc < m) return false;

  if (Ctype != MLLM_TYPE_F32) return false;

  switch (Atype) {
    case MLLM_TYPE_F32: {
      if (Btype != MLLM_TYPE_F32) return false;
#if defined(__AVX512F__)
      if (k % 16) return false;
      return true;
#elif defined(__AVX__) || defined(__AVX2__)
      if (k % 8) return false;
      return true;
#elif defined(__ARM_NEON)
      if (n < 4) return false;
      if (k % 4) return false;
      return true;
#else
      return false;
#endif
    }

    case MLLM_TYPE_F16: {
#if defined(__AVX512F__)
      if (k % 16) return false;
      if (Btype != MLLM_TYPE_F32) return false;
      return true;
#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
      if (k % 8) return false;
      if (Btype != MLLM_TYPE_F32) return false;
      return true;
#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
      if (n < 8) return false;
      if (k % 8) return false;
      if (Btype != MLLM_TYPE_F16) return false;
      return true;
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
      if (k % 4) return false;
      if (Btype != MLLM_TYPE_F32) return false;
      return true;
#else
      return false;
#endif
    }

    case MLLM_TYPE_Q8_0: {
      if (Btype != MLLM_TYPE_Q8_0) return false;
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
      return true;
#elif defined(__ARM_FEATURE_DOTPROD)
      return true;
#else
      return false;
#endif
    }

    case MLLM_TYPE_Q4_0: {
      if (Btype != MLLM_TYPE_Q8_0) return false;
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
      return true;
#elif defined(__ARM_FEATURE_DOTPROD)
      return true;
#else
      return false;
#endif
    }

    default: return false;
  }
}
}  // namespace mllm::cpu